from typing import List, Tuple

import numpy as np

from centralized_verification.shields.ahead_of_time_shield import AheadOfTimeShield
from centralized_verification.shields.shield import ShieldResult, AgentResult, AgentUpdate, T


class CentralizedShieldOracle(AheadOfTimeShield):
    """Centralized shield that checks joint actions against a safety lookup table.

    If the proposed joint action is not marked safe, agents are switched to the
    default action (0) one at a time, in a uniformly random order, until the
    resulting joint action is marked safe. The shield keeps no state of its own
    (its shield state is always ``None``).
    """

    def get_initial_shield_state(self, state, initial_joint_obs) -> T:
        """Return the initial shield state; this shield is stateless, so always ``None``."""
        return None

    def evaluate_joint_action(self, state, _, proposed_action, __) -> Tuple[ShieldResult, None]:
        """Check a proposed joint action and neuter agents until it is safe.

        Args:
            state: Passed to ``self.get_action_set`` to obtain the safety table
                used for this step.
            _: Unused (presumably the joint observation — confirm against the
                base-class signature).
            proposed_action: Indexable sequence with one action per agent.
            __: Unused (presumably the shield state, which is always ``None``
                for this shield).

        Returns:
            A ``(shield_result, None)`` pair. ``shield_result`` has one entry per
            agent: agents that keep their action get
            ``AgentResult(AgentUpdate(action=...))``; agents switched to the
            default action get ``self.replace_action_agent_result(original, 0)``.

        Raises:
            AssertionError: If the joint action is still unsafe after every agent
                has been switched to the default action.
        """
        # NOTE(review): assumes get_action_set returns a table indexable by a tuple
        # of per-agent actions, with truthy entries for safe joint actions.
        action_set = self.get_action_set(state)

        shield_result = [AgentResult(AgentUpdate(action=action)) for action in proposed_action]

        # Always index with a tuple (as the loop below does) so a list/ndarray
        # proposed_action selects a single entry instead of triggering numpy
        # fancy indexing.
        if action_set[tuple(proposed_action)]:
            return shield_result, None

        # Agents in uniformly random order.
        # noinspection PyTypeChecker
        priority: List[int] = np.random.permutation(len(proposed_action)).tolist()

        proposed_action_list = list(proposed_action)
        for agent_to_neuter in priority:
            # Set individual agents to the default action until the joint action is
            # safe. Neutering is cumulative: earlier agents stay at the default.
            proposed_action_list[agent_to_neuter] = 0
            shield_result[agent_to_neuter] = self.replace_action_agent_result(proposed_action[agent_to_neuter], 0)

            if action_set[tuple(proposed_action_list)]:
                return shield_result, None

        # Raise explicitly instead of `assert False`: assert statements are stripped
        # under `python -O`, which would make this method silently return None.
        raise AssertionError("Default action for all agents was not safe, this should never happen")
